import base64

import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, get_worker_info

from torchvision.io import decode_image, ImageReadMode


class ImageDataset(IterableDataset):
    """Stream decoded images from a JSON-lines file.

    Each line of the file is read as one record with the fields ``qry_id``,
    ``img_rank`` and ``img_b64`` (the base64-encoded bytes of an encoded
    image). Rows are split round-robin across DataLoader workers: the file
    is read in chunks of ``num_workers`` rows and each worker keeps the row
    whose position inside the chunk equals its worker id.
    """

    def __init__(self, fp_imgs: str):
        """Remember the path of the JSON-lines file to read.

        Args:
            fp_imgs: Path passed to ``pandas.read_json`` on every iteration.
        """
        super().__init__()
        self.fp_imgs = fp_imgs

    def __iter__(self):
        """Yield ``(qry_id, img_rank, image)`` for this worker's rows.

        ``image`` is the value returned by ``decode_image`` called with
        ``mode=ImageReadMode.RGB`` on the base64-decoded bytes.
        """
        worker_info = get_worker_info()
        if worker_info is None:
            # Main-process loading (num_workers=0 or direct iteration):
            # behave like a single worker that owns every row.
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        with pd.read_json(
            self.fp_imgs, lines=True, chunksize=num_workers
        ) as reader:
            for chunk in reader:
                # The last chunk can hold fewer than num_workers rows; the
                # workers beyond its end simply have nothing to take from it.
                if worker_id >= len(chunk):
                    continue

                img = chunk.iloc[worker_id]
                # np.frombuffer over bytes yields a read-only array; copy it
                # so torch.from_numpy receives a writable buffer.
                img_data = np.frombuffer(
                    base64.b64decode(img["img_b64"]), dtype="uint8"
                ).copy()
                img_data = torch.from_numpy(img_data)
                img_data = decode_image(img_data, mode=ImageReadMode.RGB)

                yield img["qry_id"], img["img_rank"], img_data
